defmodule Bumblebee.Text.Generation do
  @moduledoc """
  An interface for language models supporting sequence generation.
  """

  @type cache :: Nx.Tensor.t() | Nx.Container.t()

  @doc """
  Initializes an opaque cache input for iterative inference.
  """
  @callback init_cache(
              spec :: Bumblebee.ModelSpec.t(),
              batch_size :: pos_integer(),
              max_length :: pos_integer(),
              inputs :: map()
            ) :: cache()

  @doc """
  Traverses all batched tensors in the cache.

  This function is used when the cache needs to be inflated or
  deflated for a different batch size.
  """
  @callback traverse_cache(
              spec :: Bumblebee.ModelSpec.t(),
              cache(),
              (Nx.Tensor.t() -> Nx.Tensor.t())
            ) :: cache()

  @doc """
  Returns a configuration module for extra model-specific generation
  attributes to extend the base `Bumblebee.Text.GenerationConfig`.
  """
  @callback extra_config_module(spec :: Bumblebee.ModelSpec.t()) :: module()

  @optional_callbacks extra_config_module: 1

  import Nx.Defn

  alias Bumblebee.Utils

  @doc """
  Initializes an opaque cache input for iterative inference.
  """
  @spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
  def init_cache(%module{} = spec, batch_size, max_length, inputs) do
    module.init_cache(spec, batch_size, max_length, inputs)
  end

  @doc """
  Calls `fun` for every batched tensor in the cache.
  """
  @spec traverse_cache(
          Bumblebee.ModelSpec.t(),
          cache,
          (Nx.Tensor.t() -> Nx.Tensor.t())
        ) :: cache()
  def traverse_cache(%module{} = spec, cache, fun) do
    module.traverse_cache(spec, cache, fun)
  end

  @doc """
  Returns a configuration module for extra model-specific generation
  attributes to extend the base `Bumblebee.Text.GenerationConfig`.
  """
  @spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil
  def extra_config_module(%module{} = spec) do
    if Code.ensure_loaded?(module) and function_exported?(module, :extra_config_module, 1) do
      module.extra_config_module(spec)
    end
  end

  @doc """
  Builds a numerical definition that generates sequences of tokens using
  the given language model.

  The model should be either a decoder or an encoder-decoder. The tokens
  are generated by iterative inference using the decoder (autoregression),
  until the termination criteria are met.

  In case of encoder-decoder models, the corresponding encoder is run
  only once and the intermediate state is reused during all iterations.

  The generation is controlled by a number of options given as
  `%Bumblebee.Text.GenerationConfig{}`, see the corresponding docs
  for more details.

  Returns a defn JIT-compatible anonymous function, which expects the
  model params as the first argument and inputs map as the second
  argument. Note that the inputs map should additionally include a
  `"seed"` tensor, with one value per input in the batch.

  ## Streaming

  This function sets up a hook that is invoked after every generated
  token. The hook receives a map with the following attributes:

    * `:token_id` - the newly generated token

    * `:finished?` - a boolean indicating if the sequence is finished

    * `:length` - the current length of the generated sequence. Once
      the sequence is finished, the length does not increase

  Each of the attributes is a tensor with a leading batch dimension.

  When streaming you may not care about the output result, in which
  case you can enable `:ignore_output` to reduce the output size.

  ## Options

    * `:logits_processors` - a list of numerical functions to modify
      predicted scores at each generation step. The functions are
      applied in order, after all default processors

    * `:ignore_output` - if true, returns a dummy tensor that should
      be ignored. This is useful when you consume the generated tokens
      in a stream fashion via the hook, so that the full output does
      not need to be transferred unnecessarily after the computation.
      Defaults to `false`

  """
  @spec build_generate(
          Axon.t(),
          Bumblebee.ModelSpec.t(),
          Bumblebee.Text.GenerationConfig.t(),
          keyword()
        ) ::
          (params :: map(), inputs :: map() ->
             %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()} | (ignored :: Nx.Tensor.t()))
  def build_generate(model, spec, config, opts \\ []) do
    opts = Keyword.validate!(opts, logits_processors: [], ignore_output: false)

    decoder_start_token_id = config.decoder_start_token_id || config.bos_token_id
    eos_token_id = config.eos_token_id
    pad_token_id = config.pad_token_id || config.eos_token_id

    unless pad_token_id do
      raise ArgumentError,
            "expected :pad_token_id (or :eos_token_id as a fallback) to be present in" <>
              " generation config, but it was not set. Make sure to load a complete" <>
              " generation config or update it accordingly using Bumblebee.configure/2"
    end

    {max_length_fun, min_length_fun} = lazy_lengths_from_opts(config)

    {prepare_inputs_fun, update_inputs_fun} =
      input_callbacks(model, spec, max_length_fun, decoder_start_token_id)

    traverse_cache_fun = &traverse_cache(spec, &1, &2)

    global_layer_options =
      if config.strategy.type == :contrastive_search do
        [output_hidden_states: true]
      else
        []
      end

    {_init_fun, predict_fun} = Axon.build(model, global_layer_options: global_layer_options)

    logits_processor_fun = get_logits_processor(min_length_fun, config, opts[:logits_processors])

    &generate_impl(
      &2,
      predict_fun,
      &1,
      logits_processor_fun,
      prepare_inputs_fun,
      update_inputs_fun,
      traverse_cache_fun,
      pad_token_id: pad_token_id,
      eos_token_id: eos_token_id,
      strategy: config.strategy,
      ignore_output: opts[:ignore_output]
    )
  end

  defp lazy_lengths_from_opts(opts) do
    max_length_fun =
      case {opts.max_new_tokens, opts.max_length} do
        {nil, nil} ->
          raise ArgumentError,
                "expected either :max_new_tokens or :max_length option, but neither was given"

        {max_new_tokens, nil} ->
          fn input_length -> input_length + max_new_tokens end

        {nil, max_length} ->
          fn _ -> max_length end
      end

    min_length_fun =
      case {opts.min_new_tokens, opts.min_length} do
        {nil, nil} ->
          nil

        {min_new_tokens, nil} ->
          fn input_length -> input_length + min_new_tokens end

        {nil, min_length} ->
          fn _ -> min_length end
      end

    {max_length_fun, min_length_fun}
  end

  defp encoder_from_encoder_decoder(model) do
    # We cherry-pick encoder outputs from the encoder-decoder outputs.
    # The expanded expression will have no decoder bits, so it will
    # effectively be the same as an encoder built from scratch

    Axon.nx(model, fn outputs ->
      case outputs do
        %{
          encoder_hidden_state: hidden_state,
          encoder_hidden_states: hidden_states,
          encoder_attentions: attentions
        } ->
          %{
            hidden_state: hidden_state,
            hidden_states: hidden_states,
            attentions: attentions
          }

        _ ->
          raise ArgumentError,
                "expected an encoder-decoder model, but it does not have the expected outputs"
      end
    end)
  end

  defp input_callbacks(model, spec, max_length_fun, decoder_start_token_id) do
    if encoder_decoder?(model) do
      encoder = encoder_from_encoder_decoder(model)
      {_encoder_init_fun, encoder_predict_fun} = Axon.build(encoder)

      prepare_inputs_fun = fn inputs, params ->
        encoder_outputs = encoder_predict_fun.(params, inputs)

        batch_size = Nx.axis_size(encoder_input(inputs), 0)
        decoder_input_ids = Nx.broadcast(decoder_start_token_id, {batch_size, 1})

        inputs =
          Map.merge(inputs, %{
            "encoder_hidden_state" => encoder_outputs.hidden_state,
            "decoder_input_ids" => decoder_input_ids
          })

        max_length = max_length_fun.(1)
        inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
        {inputs, inputs["decoder_input_ids"], max_length}
      end

      update_inputs_fun = &update_decoder_inputs("decoder_", &1, &2, &3)

      {prepare_inputs_fun, update_inputs_fun}
    else
      prepare_inputs_fun = fn inputs, _params ->
        sequence_length = Nx.axis_size(inputs["input_ids"], 1)
        max_length = max_length_fun.(sequence_length)
        inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
        {inputs, inputs["input_ids"], max_length}
      end

      update_inputs_fun = &update_decoder_inputs("", &1, &2, &3)

      {prepare_inputs_fun, update_inputs_fun}
    end
  end

  defp encoder_decoder?(model) do
    inputs = Axon.get_inputs(model)
    encoder_input(inputs) != nil and Map.has_key?(inputs, "decoder_input_ids")
  end

  defp encoder_input(inputs) do
    inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
  end

  defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
    input_ids = inputs[prefix <> "input_ids"]
    attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)

    position_ids =
      attention_mask
      |> Nx.cumulative_sum(axis: 1)
      |> Nx.subtract(1)

    inputs =
      inputs
      |> Map.put(prefix <> "attention_mask", attention_mask)
      |> Map.put(prefix <> "position_ids", position_ids)

    batch_size = Nx.axis_size(input_ids, 0)
    cache = init_cache(spec, batch_size, max_length, inputs)

    output_policy = model_output_policy(model)

    # TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
    # match Axon compiler

    # Cast all float cache tensors to match the model output. This way
    # we make sure the cache we pass as input has the same types as
    # the updated cache returned from the model
    cache =
      Bumblebee.Utils.Nx.map(cache, fn tensor ->
        if Nx.Type.integer?(Nx.type(tensor)) do
          tensor
        else
          Axon.MixedPrecision.cast(output_policy, tensor, :output)
        end
      end)

    Map.put(inputs, "cache", cache)
  end

  defp model_output_policy(model) do
    {node, _} = Axon.pop_node(model)
    node.policy
  end

  defp update_decoder_inputs(prefix, inputs, cache, token_ids) do
    inputs
    |> Map.replace!(prefix <> "input_ids", token_ids)
    |> Map.replace!(prefix <> "attention_mask", Nx.broadcast(1, token_ids))
    |> Map.update!(prefix <> "position_ids", fn position_ids ->
      position_ids
      |> Nx.slice_along_axis(Nx.axis_size(position_ids, -1) - 1, 1, axis: -1)
      |> Nx.add(1)
    end)
    |> Map.replace!("cache", cache)
  end

  defp get_logits_processor(min_length_fun, config, logits_processors) do
    import Bumblebee.Text.Generation.LogitsProcessing

    processors =
      [
        if config.no_repeat_ngram_length && config.no_repeat_ngram_length > 0 do
          &no_repeat_ngram_processor(&1, &2, ngram_length: config.no_repeat_ngram_length)
        end,
        if min_length_fun && config.eos_token_id do
          &min_length_processor(&1, &2,
            min_length_fun: min_length_fun,
            eos_token_id: config.eos_token_id
          )
        end,
        if config.suppressed_token_ids != [] do
          &suppressed_tokens_processor(&1, &2, suppressed_token_ids: config.suppressed_token_ids)
        end,
        if config.forced_bos_token_id do
          &bos_token_processor(&1, &2, bos_token_id: config.forced_bos_token_id)
        end,
        if config.forced_eos_token_id do
          &eos_token_processor(&1, &2, eos_token_id: config.forced_eos_token_id)
        end,
        if config.forced_token_ids do
          &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
        end,
        if config.temperature && config.temperature != 1.0 do
          &temperature_processor(&1, &2, temperature: config.temperature)
        end
      ] ++
        if config.strategy.type == :multinomial_sampling do
          [
            if top_k = config.strategy[:top_k] do
              &top_k_processor(&1, &2, top_k: top_k)
            end,
            if top_p = config.strategy[:top_p] do
              &top_p_processor(&1, &2, top_p: top_p)
            end
          ]
        else
          []
        end ++ logits_processors

    fn logits, context ->
      for processor <- processors, processor, reduce: logits do
        logits -> processor.(logits, context)
      end
    end
  end

  defnp generate_impl(
          inputs,
          predict_fun,
          params,
          logits_processor_fun,
          prepare_inputs_fun,
          update_inputs_fun,
          traverse_cache_fun,
          opts \\ []
        ) do
    {seed, inputs} = pop_seed(inputs)

    {decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)

    length = Nx.axis_size(decoder_input_ids, 1)

    if length >= max_length do
      raise ArgumentError,
            "the input sequence has #{length} tokens, but max_length is set to #{max_length}." <>
              " Consider increasing :max_new_tokens (or :max_length), or padding the input to a shorter length"
    end

    strategy = opts[:strategy]

    state =
      case strategy.type do
        :greedy_search ->
          greedy(
            decoder_inputs,
            decoder_input_ids,
            predict_fun,
            params,
            logits_processor_fun,
            update_inputs_fun,
            merge_options([max_length: max_length], opts)
          )

        :contrastive_search ->
          contrastive(
            decoder_inputs,
            decoder_input_ids,
            predict_fun,
            params,
            logits_processor_fun,
            update_inputs_fun,
            traverse_cache_fun,
            merge_options(
              [max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha],
              opts
            )
          )

        :multinomial_sampling ->
          sampling(
            decoder_inputs,
            decoder_input_ids,
            predict_fun,
            params,
            seed,
            logits_processor_fun,
            update_inputs_fun,
            merge_options([max_length: max_length], opts)
          )
      end

    if opts[:ignore_output] do
      state.ignored
    else
      %{
        # Output only the newly generated tokens
        token_ids: state.sequences[[.., length..-1//1]],
        length: state.finished_length - length
      }
    end
  end

  deftransformp pop_seed(inputs), do: Map.pop!(inputs, "seed")

  deftransformp merge_options(left, right), do: left ++ right

  # Greedy search

  defnp greedy(
          inputs,
          decoder_input_ids,
          predict_fun,
          params,
          logits_processor_fun,
          update_inputs_fun,
          opts \\ []
        ) do
    max_length = opts[:max_length]
    pad_token_id = opts[:pad_token_id]
    eos_token_id = opts[:eos_token_id]

    state = init_sequences(decoder_input_ids, max_length, pad_token_id)

    # The loop works with inputs of length 1, so if the initial input
    # is longer, we make the initial pass outside
    {state, inputs} =
      if state.length > 1 do
        greedy_step(
          state,
          inputs,
          predict_fun,
          params,
          logits_processor_fun,
          update_inputs_fun,
          pad_token_id: pad_token_id,
          eos_token_id: eos_token_id
        )
      else
        {state, inputs}
      end

    {state, _inputs, _params} =
      while {state, inputs, params}, continue?(state.finished_length) do
        {state, inputs} =
          greedy_step(
            state,
            inputs,
            predict_fun,
            params,
            logits_processor_fun,
            update_inputs_fun,
            pad_token_id: pad_token_id,
            eos_token_id: eos_token_id
          )

        {state, inputs, params}
      end

    state
  end

  defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
    {batch_size, length} = Nx.shape(decoder_input_ids)

    sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
    sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)

    # For each sequence, we keep track of its final length, where 0
    # means that it has not been finished yet
    finished_length = Nx.broadcast(0, {batch_size})

    %{
      sequences: sequences,
      input_length: length,
      length: length,
      finished_length: finished_length,
      # The ignored return value that we attach all hooks to
      ignored: Nx.broadcast(0, {batch_size})
    }
  end

  defnp continue?(finished_length) do
    Nx.any(finished_length == 0)
  end

  defnp greedy_step(
          state,
          inputs,
          predict_fun,
          params,
          logits_processor_fun,
          update_inputs_fun,
          opts
        ) do
    pad_token_id = opts[:pad_token_id]
    eos_token_id = opts[:eos_token_id]

    outputs = predict_fun.(params, inputs)

    logits = outputs.logits[[.., -1]]
    logits = batch_process_logits(logits_processor_fun, logits, state)
    token_id = Nx.argmax(logits, axis: -1)

    state = update_sequences(state, token_id, pad_token_id, eos_token_id)

    inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))

    {state, inputs}
  end

  defnp update_sequences(state, token_id, pad_token_id, eos_token_id) do
    %{
      sequences: sequences,
      length: length,
      input_length: input_length,
      finished_length: finished_length
    } = state

    token_id = Nx.select(finished_length > 0, pad_token_id, token_id)

    token_ids = Nx.new_axis(token_id, -1)
    sequences = Nx.put_slice(sequences, [0, length], token_ids)
    length = length + 1

    {batch_size, max_length} = Nx.shape(sequences)

    finished_length =
      case eos_token_id do
        nil ->
          finished_length

        eos_token_id ->
          Nx.select(
            finished_length == 0 and (token_id == eos_token_id or length == max_length),
            length,
            finished_length
          )
      end

    finished? = finished_length > 0
    output_length = Nx.broadcast(length - input_length, {batch_size})
    data = %{token_id: token_id, finished?: finished?, length: output_length}
    token = create_token()
    {token, _} = hook_token(token, data, :token)

    state = %{state | sequences: sequences, length: length, finished_length: finished_length}
    attach_token(token, state)
  end

  defnp batch_process_logits(logits_processor_fun, logits, state) do
    logits
    |> Nx.vectorize(:batch)
    |> logits_processor_fun.(%{
      sequence: Nx.vectorize(state.sequences, :batch),
      length: state.length,
      input_length: state.input_length
    })
    |> Nx.devectorize(keep_names: false)
  end

  # Contrastive search

  defnp contrastive(
          inputs,
          decoder_input_ids,
          predict_fun,
          params,
          logits_processor_fun,
          update_inputs_fun,
          traverse_cache_fun,
          opts \\ []
        ) do
    max_length = opts[:max_length]
    pad_token_id = opts[:pad_token_id]
    eos_token_id = opts[:eos_token_id]
    top_k = opts[:top_k]
    penalty_alpha = opts[:penalty_alpha]

    state = init_sequences(decoder_input_ids, max_length, pad_token_id)

    # Step (1)
    # Initial pass to obtain hidden state and expand inputs to top-k

    outputs = predict_fun.(params, inputs)

    # Later, we feed model a single token at a time and reuse previous
    # results using cache. Here we need the final hidden state, so we
    # need to keep track of it in a similar way
    initial_hidden_state = decoder_hidden_state(outputs)
    batch_size = Nx.axis_size(initial_hidden_state, 0)
    hidden_size = Nx.axis_size(initial_hidden_state, -1)

    joint_hidden_state =
      Nx.broadcast(
        Nx.tensor(0, type: Nx.type(initial_hidden_state)),
        {batch_size, max_length, hidden_size}
      )

    joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state)

    logits = outputs.logits[[.., -1]]
    logits = batch_process_logits(logits_processor_fun, logits, state)
    scores = Axon.Activations.softmax(logits, axis: -1)
    {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)

    # For subsequent model passes we consider several (top-k) paths
    # for each batch item, so we duplicate inputs and cache accordingly
    inputs = expand_inputs(inputs, top_k)
    cache = expand_cache(outputs.cache, top_k, traverse_cache_fun)
    inputs = update_inputs_fun.(inputs, cache, Nx.reshape(top_k_token_ids, {:auto, 1}))

    # Step (2)
    # In the loop we make prediction for top-k continuation tokens and
    # pick the best one using the contrastive rank. From the same model
    # pass we also get the next top-k continuation tokens

    {state, _inputs, _params, _joint_hidden_state, _top_k_values} =
      while {state, inputs, params, joint_hidden_state, {top_k_scores, top_k_token_ids}},
            continue?(state.finished_length) do
        outputs = predict_fun.(params, inputs)

        hidden_state = decoder_hidden_state(outputs)

        context_hidden_state = Utils.Nx.repeat_interleave(joint_hidden_state, top_k)

        selected_idx =
          contrastive_rank(
            context_hidden_state,
            hidden_state,
            state.length,
            top_k_scores,
            penalty_alpha,
            top_k
          )

        hidden_state = Utils.Nx.chunked_take(hidden_state, top_k, selected_idx)
        joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, state.length, 0], hidden_state)

        token_id = top_k_token_ids |> Nx.flatten() |> Utils.Nx.chunked_take(top_k, selected_idx)

        state = update_sequences(state, token_id, pad_token_id, eos_token_id)

        logits = outputs.logits[[.., -1]]
        logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
        logits = batch_process_logits(logits_processor_fun, logits, state)

        scores = Axon.Activations.softmax(logits, axis: -1)
        {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)

        # Mirror the selected idx to other entries within each chunk
        cache = reflect_cache(outputs.cache, top_k, selected_idx, traverse_cache_fun)
        inputs = update_inputs_fun.(inputs, cache, Nx.reshape(top_k_token_ids, {:auto, 1}))

        {state, inputs, params, joint_hidden_state, {top_k_scores, top_k_token_ids}}
      end

    state
  end

  deftransformp decoder_hidden_state(outputs) do
    hidden_states =
      case outputs do
        %{decoder_hidden_states: hidden_states} -> hidden_states
        %{hidden_states: hidden_states} -> hidden_states
      end

    elem(hidden_states, tuple_size(hidden_states) - 1)
  end

  deftransformp expand_inputs(inputs, times) do
    Map.new(inputs, fn
      {key, value} when key in ["cache"] ->
        {key, value}

      {key, %Nx.Tensor{} = value} ->
        {key, Utils.Nx.repeat_interleave(value, times)}
    end)
  end

  deftransformp expand_cache(cache, times, traverse_cache_fun) do
    traverse_cache_fun.(cache, &Utils.Nx.repeat_interleave(&1, times))
  end

  deftransformp reflect_cache(cache, times, idx, traverse_cache_fun) do
    traverse_cache_fun.(
      cache,
      &(&1
        |> Utils.Nx.chunked_take(times, idx)
        |> Utils.Nx.repeat_interleave(times))
    )
  end

  defnp contrastive_rank(
          context_hidden_state,
          hidden_state,
          length,
          top_k_scores,
          penalty_alpha,
          top_k
        ) do
    similarity_matrix =
      context_hidden_state
      |> Bumblebee.Utils.Nx.cosine_similarity(hidden_state, batched?: true)
      # hidden_state has sequence length of 1, so the batch of similarity
      # matrices has shape {batch_size * top_k, max_length, 1} and we
      # flatten out the last dimension
      |> Nx.squeeze(axes: [-1])

    # context_hidden_state includes placeholder values for tokens up
    # to max_length, so we need to ignore these
    current_sequence? = Nx.iota(Nx.shape(similarity_matrix), axis: -1) < length

    degeneration_penalty =
      current_sequence?
      |> Nx.select(similarity_matrix, Nx.Constants.neg_infinity())
      |> Nx.reduce_max(axes: [-1])

    contrastive_score =
      (1.0 - penalty_alpha) * Nx.flatten(top_k_scores) - penalty_alpha * degeneration_penalty

    contrastive_score
    |> Nx.reshape({:auto, top_k})
    |> Nx.argmax(axis: -1)
  end

  # Multinomial sampling

  defnp sampling(
          inputs,
          decoder_input_ids,
          predict_fun,
          params,
          seed,
          logits_processor_fun,
          update_inputs_fun,
          opts \\ []
        ) do
    max_length = opts[:max_length]
    pad_token_id = opts[:pad_token_id]
    eos_token_id = opts[:eos_token_id]

    state = init_sequences(decoder_input_ids, max_length, pad_token_id)

    prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()

    # The loop works with inputs of length 1, so if the initial input
    # is longer, we make the initial pass outside
    {state, inputs, prng_key} =
      if state.length > 1 do
        sampling_step(
          state,
          inputs,
          predict_fun,
          params,
          prng_key,
          logits_processor_fun,
          update_inputs_fun,
          pad_token_id: pad_token_id,
          eos_token_id: eos_token_id
        )
      else
        {state, inputs, prng_key}
      end

    {state, _inputs, _params, _key} =
      while {state, inputs, params, prng_key}, continue?(state.finished_length) do
        {state, inputs, prng_key} =
          sampling_step(
            state,
            inputs,
            predict_fun,
            params,
            prng_key,
            logits_processor_fun,
            update_inputs_fun,
            pad_token_id: pad_token_id,
            eos_token_id: eos_token_id
          )

        {state, inputs, params, prng_key}
      end

    state
  end

  defnp sampling_step(
          state,
          inputs,
          predict_fun,
          params,
          prng_key,
          logits_processor_fun,
          update_inputs_fun,
          opts \\ []
        ) do
    pad_token_id = opts[:pad_token_id]
    eos_token_id = opts[:eos_token_id]

    key = Nx.Random.split(prng_key)
    {key, prng_key} = {key[1], key[0]}

    outputs = predict_fun.(params, inputs)

    logits = outputs.logits[[.., -1]]
    logits = batch_process_logits(logits_processor_fun, logits, state)
    scores = Axon.Activations.softmax(logits)
    token_id = batched_choice(key, scores)

    state = update_sequences(state, token_id, pad_token_id, eos_token_id)

    inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))

    {state, inputs, prng_key}
  end

  deftransformp batched_choice(key, scores) do
    vocab_size = Nx.axis_size(scores, 1)

    vocab = Nx.iota({vocab_size})

    probabilities = Nx.vectorize(scores, :batch)

    {tokens, _} = Nx.Random.choice(key, vocab, probabilities, samples: 1)

    tokens
    |> Nx.squeeze()
    |> Nx.devectorize()
  end
end
